{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# A simple open-domain QA pipeline"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Below we will demonstrate how to build an open-domain QA pipeline using the unique components from fastRAG. \n",
    "\n",
    "We will use a simple `BM25Retriever` retriever, a neural re-ranker (based on SBERT)  `SentenceTransformersRanker` model and a `Fusion-in-Decoder` model to generate answers given the retrieved evidence. "
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Build a local in-memory index and store sample documents"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2023-09-27 09:25:21,124] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n"
     ]
    }
   ],
   "source": [
    "from haystack.document_stores import InMemoryDocumentStore\n",
    "\n",
    "document_store = InMemoryDocumentStore(use_gpu=False, use_bm25=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Updating BM25 representation...: 100%|████████████████████████████████████████████| 3/3 [00:00<00:00, 41255.45 docs/s]\n"
     ]
    }
   ],
   "source": [
    "from haystack.schema import Document\n",
    "\n",
    "# 3 example documents to index\n",
    "examples = [\n",
    "    \"There is a blue house on Oxford street\",\n",
    "    \"Paris is the capital of France\",\n",
    "    \"fastRAG had its first commit in 2022\"\n",
    "]\n",
    "\n",
    "documents = []\n",
    "for i, d in enumerate(examples):\n",
    "    documents.append(Document(content=d, id=i))\n",
    "\n",
    "document_store.write_documents(documents)"
   ]
  },
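  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an optional sanity check, we can confirm the documents were indexed. This sketch uses the document store's `get_document_count()` method."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# number of documents currently held by the in-memory store\n",
    "document_store.get_document_count()"
   ]
  },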
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialize the pipeline components"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Initialize the components we are going to use in our pipeline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[09/27/2023 09:25:31] {utils.py:130} INFO - Using devices: CUDA:0, CUDA:1, CUDA:2, CUDA:3, CUDA:4, CUDA:5, CUDA:6, CUDA:7 - Number of GPUs: 8\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "365f2ac5a2014be7b4231a8c995f1ef6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading (…)lve/main/config.json:   0%|          | 0.00/791 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8e0f67ed44184761b65e61dbec9a802a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading pytorch_model.bin:   0%|          | 0.00/134M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.10/dist-packages/torch/_utils.py:776: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n",
      "  return self.fget.__get__(instance, owner)()\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "57a08588c32d4b95a836c418a3108049",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading (…)okenizer_config.json:   0%|          | 0.00/316 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d2c1b1cdf10d47799cb20a40c9e3f939",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2b39706c27b44e0284efbd6fdab57e01",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading (…)cial_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from haystack.nodes import BM25Retriever, SentenceTransformersRanker\n",
    "\n",
    "# define a BM25 retriever, ST re-ranker and FiD reader based on a local model\n",
    "retriever = BM25Retriever(document_store=document_store)\n",
    "reranker = SentenceTransformersRanker(model_name_or_path=\"cross-encoder/ms-marco-MiniLM-L-12-v2\")"
   ]
  },
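  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before wiring everything into a pipeline, we can query the retriever on its own. This is just a quick check; `retrieve()` returns a list of `Document` objects ranked by BM25."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# standalone BM25 lookup over the in-memory index\n",
    "retriever.retrieve(query=\"What is Paris?\", top_k=2)"
   ]
  },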
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[09/27/2023 09:25:51] {utils.py:130} INFO - Using devices: CUDA:0 - Number of GPUs: 1\n",
      "Xformers is not installed correctly. If you want to use memory_efficient_attention to accelerate training use the following command to install Xformers\n",
      "pip install xformers.\n",
      "The model 'FusionInDecoderForConditionalGeneration' is not supported for text-generation. Supported models are ['BartForCausalLM', 'BertLMHeadModel', 'BertGenerationDecoder', 'BigBirdForCausalLM', 'BigBirdPegasusForCausalLM', 'BioGptForCausalLM', 'BlenderbotForCausalLM', 'BlenderbotSmallForCausalLM', 'BloomForCausalLM', 'CamembertForCausalLM', 'CodeGenForCausalLM', 'CpmAntForCausalLM', 'CTRLLMHeadModel', 'Data2VecTextForCausalLM', 'ElectraForCausalLM', 'ErnieForCausalLM', 'FalconForCausalLM', 'GitForCausalLM', 'GPT2LMHeadModel', 'GPT2LMHeadModel', 'GPTBigCodeForCausalLM', 'GPTNeoForCausalLM', 'GPTNeoXForCausalLM', 'GPTNeoXJapaneseForCausalLM', 'GPTJForCausalLM', 'LlamaForCausalLM', 'MarianForCausalLM', 'MBartForCausalLM', 'MegaForCausalLM', 'MegatronBertForCausalLM', 'MusicgenForCausalLM', 'MvpForCausalLM', 'OpenLlamaForCausalLM', 'OpenAIGPTLMHeadModel', 'OPTForCausalLM', 'PegasusForCausalLM', 'PLBartForCausalLM', 'ProphetNetForCausalLM', 'QDQBertLMHeadModel', 'ReformerModelWithLMHead', 'RemBertForCausalLM', 'RobertaForCausalLM', 'RobertaPreLayerNormForCausalLM', 'RoCBertForCausalLM', 'RoFormerForCausalLM', 'RwkvForCausalLM', 'Speech2Text2ForCausalLM', 'TransfoXLLMHeadModel', 'TrOCRForCausalLM', 'XGLMForCausalLM', 'XLMWithLMHeadModel', 'XLMProphetNetForCausalLM', 'XLMRobertaForCausalLM', 'XLMRobertaXLForCausalLM', 'XLNetLMHeadModel', 'XmodForCausalLM'].\n",
      "[09/27/2023 09:25:53] {FiD.py:64} INFO - tokenizer max length is:512\n"
     ]
    }
   ],
   "source": [
    "from fastrag.prompters.invocation_layers import fid \n",
    "from haystack.nodes import PromptModel\n",
    "from haystack.nodes.prompt.prompt_template import PromptTemplate\n",
    "from haystack.nodes.prompt import PromptNode\n",
    "import torch\n",
    "\n",
    "PrompterModel = PromptModel(\n",
    "    model_name_or_path= \"Intel/fid_flan_t5_base_nq\",\n",
    "    use_gpu= True,\n",
    "    invocation_layer_class=fid.FiDHFLocalInvocationLayer,\n",
    "    model_kwargs= dict(\n",
    "        model_kwargs= dict(\n",
    "            device_map= {\"\": 0},\n",
    "            torch_dtype  = torch.bfloat16,\n",
    "            do_sample=False\n",
    "        ),\n",
    "        generation_kwargs=dict(\n",
    "            max_length=10\n",
    "        )\n",
    "    )\n",
    ")\n",
    "\n",
    "reader = PromptNode(\n",
    "    model_name_or_path= PrompterModel,\n",
    "    default_prompt_template=PromptTemplate(\"{query}\")\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create a pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from haystack import Pipeline\n",
    "\n",
    "p = Pipeline()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Add the components in the right order"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "p.add_node(component=retriever, name=\"Retriever\", inputs=[\"Query\"])\n",
    "p.add_node(component=reranker, name=\"Reranker\", inputs=[\"Retriever\"])\n",
    "p.add_node(component=reader, name=\"Reader\", inputs=[\"Reranker\"])"
   ]
  },
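  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Optionally, we can render the pipeline graph to confirm the nodes are connected as intended. This sketch uses Haystack's `Pipeline.draw()`, which writes a PNG to disk and assumes `pygraphviz` (and Graphviz) is installed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# requires pygraphviz; writes the pipeline graph to pipeline.png\n",
    "p.draw(\"pipeline.png\")"
   ]
  },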
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run a query through the pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "res = p.run(query=\"What is Paris?\")"
   ]
  },
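  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With only three documents in the store the defaults are fine, but on a larger index you would typically cap how many candidates flow between the nodes. Haystack supports this via per-node `params`; the `top_k` values below are illustrative."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# retrieve 10 BM25 candidates, keep the top 3 after re-ranking\n",
    "res = p.run(\n",
    "    query=\"What is Paris?\",\n",
    "    params={\"Retriever\": {\"top_k\": 10}, \"Reranker\": {\"top_k\": 3}},\n",
    ")"
   ]
  },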
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Display the answer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'the capital of France'"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "res['results'][0]"
   ]
  },
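  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Besides the generated answer, it is often useful to see which passages the reader conditioned on. A small sketch, assuming the pipeline output exposes the re-ranked documents under the usual Haystack `documents` key."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# print the re-ranked passages that were passed to the FiD reader\n",
    "for doc in res.get(\"documents\", []):\n",
    "    print(doc.id, \"-\", doc.content)"
   ]
  },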
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "vscode": {
   "interpreter": {
    "hash": "01262dede09baa68616418263efd26d33bafc508f82c218954990f624836b45d"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
